{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Dataset X: three 2-D Gaussian clusters.\n",
    "# Seed the global NumPy RNG so Restart & Run All regenerates the same samples\n",
    "# (the later np.random.shuffle call also draws from this global state).\n",
    "np.random.seed(0)\n",
    "\n",
    "n_samples = 100  # points drawn per cluster\n",
    "\n",
    "mus = [[0, 4], [-2, 0], [4, 4]]\n",
    "\n",
    "sigs = [[[3, 0], [0, 0.5]], [[1, 0], [0, 2]], [[1, 0], [0, 1]]]\n",
    "\n",
    "# One array of n_samples draws per (mean, covariance) pair\n",
    "multi_data = [np.random.multivariate_normal(mu, sig, n_samples) for mu, sig in zip(mus, sigs)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Stack the per-cluster sample arrays row-wise into a single array\n",
    "data = np.concatenate(multi_data, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "## Class labels are 0, 1, 2 -- one label per sample, with each count taken\n",
    "## from the matching cluster so labels stay aligned with `data`.\n",
    "labels = [[value] * len(samples) for value, samples in enumerate(multi_data)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]]\n"
     ]
    }
   ],
   "source": [
    "print(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "## Flatten the nested per-class label lists into one flat list\n",
    "multi_labels = []\n",
    "for class_labels in labels:\n",
    "    multi_labels.extend(class_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n"
     ]
    }
   ],
   "source": [
    "print(multi_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Merge samples and labels: each row becomes [x1, x2, label].\n",
    "# zip() silently truncates on a length mismatch, so check alignment first.\n",
    "assert len(data) == len(multi_labels), 'data and labels must have the same length'\n",
    "\n",
    "final_data = [row.tolist() + [label] for row, label in zip(data, multi_labels)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[3.5593897975906197, 3.736527009372364, 0], [0.6382892112881816, 3.588033841365842, 0], [0.2632863000119277, 3.6418791057870057, 0], [-3.0806583652946378, 4.313036328113317, 0], [-2.946973346565818, 4.2007797192817495, 0], [0.859832381065888, 4.062416658541541, 0], [2.0803735950949687, 3.9420024711310737, 0], [2.253177787194711, 5.114644685392081, 0], [-1.227059694348927, 4.312979802762684, 0], [2.1320743995469598, 5.424193031378502, 0], [2.5263897404541424, 4.2863931066916505, 0], [-1.6023973289952056, 5.098031584492405, 0], [0.0009144524766546955, 3.4883492002658114, 0], [1.4526835758186951, 4.502273717152065, 0], [-2.080134463932432, 4.6618379820484686, 0], [-0.7376472230709924, 5.803097007130773, 0], [-2.253314351369113, 3.953951563079367, 0], [0.3939423800045117, 3.6943452523501166, 0], [3.0968048034125757, 3.6806107244228894, 0], [2.289116415364021, 3.9870101671111535, 0], [0.6264737578152862, 3.741794266275358, 0], [-0.7864902253465736, 4.199606423310847, 0], [0.2759101397369275, 2.5995246773954275, 0], [-0.11573062205373608, 3.480755568518185, 0], [0.3630595564901449, 3.8193121848730667, 0], [-2.3541297278981905, 4.617919747329587, 0], [-0.1393500485857051, 4.169106167012477, 0], [-1.229676620128568, 4.378887070121326, 0], [-0.24222779307396392, 3.392477267161001, 0], [1.6999871219014488, 3.0271119638873043, 0], [-0.6467954979182303, 3.312358884754359, 0], [-2.0938587928573447, 4.607698040503408, 0], [-1.3993130309070427, 4.026755341718367, 0], [-1.276376069262216, 3.6380032667868294, 0], [0.42299761560666266, 4.613648175125713, 0], [2.2344063935391776, 4.463233724669758, 0], [-3.21317704887337, 4.267090754033438, 0], [-1.8585559514218042, 2.9229067964579536, 0], [-0.0797236171908722, 4.148615149643331, 0], [0.583212279631955, 2.580110125461964, 0], [-1.26937928167189, 3.9366603289412025, 0], [-0.14490604294007797, 4.061320166617641, 0], [-1.11562789887616, 4.271336233198918, 0], [-0.9363164618457444, 3.938951342152362, 0], [-0.20741064108600205, 
3.868701095470327, 0], [0.718229805430768, 4.310576411147394, 0], [0.5252337325015385, 4.488056876575435, 0], [-1.1598230851730404, 3.1633666849826403, 0], [-1.6702576184489226, 4.682037551900045, 0], [0.4797272101616133, 5.296738129193235, 0], [0.9603014627340807, 4.06899014584137, 0], [-2.1749244254585474, 4.680191769512619, 0], [0.7055929486208558, 4.310387049225572, 0], [-1.6273887432371903, 5.505535694905153, 0], [0.03566039932018928, 4.032797732705941, 0], [0.5469478265638494, 3.533230634808844, 0], [-0.6569063492545735, 4.262380962745065, 0], [4.873477831427014, 3.911043062660721, 0], [1.8884373234206377, 3.413593905948643, 0], [-1.7376573738498051, 4.0222670151939015, 0], [1.1346679230545944, 4.2453793112648475, 0], [-1.1067780728998662, 3.9078285367195775, 0], [-0.5578559773113465, 4.505703799310029, 0], [-1.4704502157593242, 3.9452476509906944, 0], [-0.8936085001831094, 4.973320112542332, 0], [1.246837949790526, 3.7406124743448106, 0], [0.695391819473179, 3.215383688084266, 0], [-3.968262247197826, 4.100680107599374, 0], [2.290300231330779, 3.3096653562533285, 0], [-0.3274045137736968, 4.339044033478316, 0], [-2.2969780465391754, 3.606601458530838, 0], [-1.2956230400219464, 3.702464792606867, 0], [-3.426929548555545, 4.731833142875785, 0], [-1.24376997962011, 4.587481260319372, 0], [-2.1057356340089886, 4.083206923044556, 0], [1.9159750400662772, 4.322823105875789, 0], [-1.0254931176690707, 3.6805312273656683, 0], [-0.6563091088989254, 5.051006292393634, 0], [-2.039116321391051, 2.949402971986678, 0], [-0.48399109180025357, 4.016821162942991, 0], [-0.2304116223343647, 4.698521110136274, 0], [2.0604561909115473, 4.936152528359206, 0], [-1.8125482079006536, 4.696281039511931, 0], [-2.7348217180204775, 2.522143850147877, 0], [0.3616868451134729, 4.461186774668471, 0], [-2.7642373882075035, 3.715354172088658, 0], [0.24511472118101468, 3.1712311787937524, 0], [-0.5158730553503137, 4.3880572535000795, 0], [1.755091160248367, 1.960071082702421, 0], 
[0.24295355936942525, 3.7690670506687525, 0], [2.8456608218091612, 3.274257814811507, 0], [-0.4329048612781419, 3.5574305717396104, 0], [0.642013224293857, 4.629178721936455, 0], [2.383292899572844, 5.017411703106532, 0], [0.5121696913796528, 4.71794986611183, 0], [2.1584434273177484, 3.382361255485536, 0], [1.3915330382176745, 3.609610779138298, 0], [-0.48303865121086914, 2.9512899473322833, 0], [-0.17410428062495628, 4.3287683736566125, 0], [-0.8946735160243385, 4.772101522859985, 0], [-1.379386248249244, 0.77658975429236, 1], [-1.8803827786671055, 0.11056830248492576, 1], [-2.3456293811887794, 1.065115954289983, 1], [-2.9033226684593885, -2.3568402619862, 1], [-3.33873248655075, -0.08069373810368652, 1], [-0.7762488068320517, 1.6847176334540306, 1], [-0.7552911522794037, -1.366119654210927, 1], [-2.4558274876205752, 0.268352523555578, 1], [-4.408400952016439, 1.7771350320195927, 1], [-0.4569849483877815, -2.0272409607900634, 1], [-1.0448773708191637, 0.1889974603839339, 1], [-2.8744247000521312, -1.6356078277295545, 1], [-3.447420227297748, 2.121301402055159, 1], [-1.8843533072735523, -0.8419113804803363, 1], [-1.9341954949245357, -3.2471436673481913, 1], [-1.5533586164818436, -2.0185339460170533, 1], [-3.1302980001607676, -0.3373076001846845, 1], [-4.201020759276739, -1.5002038235241208, 1], [-1.0490068844169445, 2.951040566785435, 1], [-2.7716362876242373, 2.1011645644950923, 1], [-2.8444606016435223, 0.9313394387612425, 1], [-2.8908749024996423, -0.1546427282985843, 1], [-1.110290442688222, -1.384521718778183, 1], [-0.34859519219417745, -0.7259618415461366, 1], [-2.14918281232061, -0.05909182563279221, 1], [-0.6380925086824782, 0.029563983412872433, 1], [-1.9732767789212167, -2.657157583190173, 1], [-0.5552360764294617, 0.2721038194690457, 1], [-2.972233246107584, 2.4394352282994163, 1], [0.5626346758673177, 0.869462093590307, 1], [-1.7048565706654364, -0.46715997941076914, 1], [-1.1728671749338666, 2.1952494264593727, 1], [-3.4734488003805533, 
-2.645835042677507, 1], [-2.3863110549753124, 0.5397482546093665, 1], [-2.5517564445509646, -1.2321718169022546, 1], [0.31344772269330834, -1.494449835471386, 1], [-0.23028120313149847, -2.3693570257731844, 1], [-1.964489469838514, 0.7689148007239601, 1], [-2.417695799060901, -0.3054211566559618, 1], [-1.2206991377198881, 2.8968506738340065, 1], [-1.3751452747477573, 2.566603177341791, 1], [-1.552126868409797, 0.30242003040005017, 1], [-0.8925301247280577, -0.7168706709934206, 1], [-1.6407126032585873, -1.732069813433546, 1], [-1.6985413150142163, -0.44526191458764275, 1], [-2.01753272907842, -0.07443334976717048, 1], [-0.4956626365384753, -0.5175352324112511, 1], [-3.1319494772141816, 1.0415140880284999, 1], [-2.7964090769620955, -1.238490430319433, 1], [-3.396610989413893, -0.481025839655896, 1], [-1.963552865256177, -1.1011084371028548, 1], [-0.8820409969261933, -0.5454598952166543, 1], [-2.5407860290347144, -0.6199577192431571, 1], [-1.6591661687572872, -0.05823522853658854, 1], [-1.2618910141262774, -0.5864298950798029, 1], [-2.182417968080539, 1.4386320553661633, 1], [-2.0166469000233413, -0.12236212135416753, 1], [-1.5428857863742746, 1.6395250707132116, 1], [-1.9491041109786618, 0.7043809221523516, 1], [-1.0991931205585597, -1.3547795730927457, 1], [-3.3468167079620157, -3.083349863028997, 1], [-1.4167420293616817, -1.250583967129359, 1], [-2.3208471127769434, 3.42272269405798, 1], [-1.2290910552777952, -2.042402155845269, 1], [-2.591119027450392, -1.4567442051773023, 1], [-2.6464073375410253, -0.7674622491902292, 1], [-0.15337287638122987, -1.330478252621237, 1], [-1.3030336514264167, -0.06048148958075131, 1], [-1.8504815714735647, 2.898935202422041, 1], [-0.8132328140379816, -1.288955651488098, 1], [-1.0621792893491195, -0.2481472844903763, 1], [-1.196559454138142, -1.1853566798593753, 1], [-2.996935282004089, 0.2857148719990113, 1], [-2.6031984384631683, 2.0409392293072406, 1], [-1.5780948144365075, -0.8462349627383353, 1], [-1.816388404189842, 
0.04502351606710421, 1], [-2.3256002163440597, -0.14409495477832063, 1], [-1.0807501147093004, 0.013209178405663127, 1], [0.8750938312946315, 0.4034302678820234, 1], [-2.9004937614795328, -0.3216117040748568, 1], [-3.342799204071513, -1.136977066023454, 1], [-3.5725747516573803, 2.192335521818208, 1], [0.5993964428518588, -0.04275487216011345, 1], [-2.164677451957254, 2.2392098796939996, 1], [-2.619061481162773, -1.7698307321638587, 1], [-3.209975590198013, 0.7092973441936908, 1], [-1.6820276081587804, 1.0208434104469255, 1], [-3.586459858726821, 2.142805514281176, 1], [-1.5562786780749314, 0.19723627569547378, 1], [-2.1699637476676146, -1.5435606932156916, 1], [-1.4854834096634106, 1.3688840569498117, 1], [-1.2334497650352687, -0.42810907016445826, 1], [-2.6421538300903284, -1.9974657027053484, 1], [-3.2844665394677968, 1.1923808246168874, 1], [-2.6322240450555676, -0.287912718061372, 1], [-2.203881737898663, 0.7043381033943146, 1], [-0.4957871575835118, 1.870276286706099, 1], [-2.9543271683694297, 1.5208661201266933, 1], [-3.604567881384009, 0.7790357409013815, 1], [-3.282863120585333, 1.80384111633762, 1], [6.136225045515941, 4.845285611496297, 2], [3.358601241933496, 3.751019934059549, 2], [2.9651806452726484, 3.6822567615252, 2], [4.665491231926977, 2.3405051084682, 2], [4.604905521458402, 5.384261022221785, 2], [4.166019184020129, 4.80837886147176, 2], [2.726847687133478, 3.8029862212398635, 2], [3.083429873694617, 3.8685653308881798, 2], [3.497280862907807, 2.9506819966389277, 2], [5.30088056551321, 4.6871101039669005, 2], [5.261642912047374, 4.043651240987026, 2], [5.053880665840948, 4.726484408898449, 2], [3.3972252940428445, 5.617149302579655, 2], [2.6533319152578354, 5.09147458811136, 2], [3.6109221872897788, 3.7523103442746994, 2], [3.8338886252094664, 4.347682114015021, 2], [3.4946270488405164, 3.6105050404937677, 2], [5.069119303815527, 4.264150799069864, 2], [5.2555452121963695, 4.876984638429942, 2], [4.402153267341136, 4.634952385707767, 2], 
[5.148943859631081, 2.6311886560593436, 2], [5.103592669048636, 6.30408075741299, 2], [4.606991707535205, 3.5275197733485495, 2], [4.350032113990249, 4.707696775095096, 2], [3.957012458207807, 3.3811425030612883, 2], [1.9354884616216883, 4.029993226729789, 2], [4.083362776456962, 4.750228926231906, 2], [4.270791491798424, 4.535809241272678, 2], [2.5527223778798636, 6.493607748081107, 2], [4.157141466979694, 3.580783763138573, 2], [5.2986441863874845, 4.943179223643645, 2], [4.666000898797912, 5.267115372176956, 2], [3.220805106984178, 3.6090911452898085, 2], [3.542165228657921, 2.9810011102630414, 2], [4.779392227428427, 2.3226977512156877, 2], [3.9267710702505663, 5.168810802001428, 2], [4.76514198141676, 3.8767860619446477, 2], [4.290199151859662, 5.0955983081795315, 2], [2.776164727808217, 3.524784380252329, 2], [4.207963718422087, 5.04274126473347, 2], [2.9782653914916417, 2.915201988049066, 2], [4.375699437400853, 2.5892950225439693, 2], [4.426970949858231, 3.9419052108087547, 2], [3.7708897158379004, 2.4944281040222593, 2], [3.9104920101067338, 2.676229432347111, 2], [6.096789524724665, 4.184960835913332, 2], [4.202740515544135, 5.187661389794847, 2], [2.5655620176720935, 3.3924447125717547, 2], [4.9242882792072, 3.2383677092167593, 2], [1.8851344601152946, 5.228459904906094, 2], [2.6676516030896726, 3.7563817279026774, 2], [2.7107246512484533, 5.59263772985132, 2], [3.2414058080462396, 4.675440322223491, 2], [4.049289535652809, 3.4653285048635665, 2], [4.0249287260083335, 2.460110034110473, 2], [5.681966284429088, 3.0282332220845847, 2], [4.146700239681935, 4.60966668621533, 2], [3.772737721339217, 4.720389099276566, 2], [4.610543309440227, 3.683317237199498, 2], [3.600288307944474, 2.431281319237196, 2], [2.890026511690566, 3.1437426794200065, 2], [2.6331310807009083, 1.6993741527422488, 2], [4.1725064709712765, 3.4732330009523205, 2], [3.4591273612186004, 2.984414700652448, 2], [3.573051669651404, 4.65614350101775, 2], [4.959576830854408, 
3.349687965447966, 2], [4.64304855128699, 2.7974921823469137, 2], [2.499890584424867, 2.44356436150654, 2], [2.507762812760115, 4.346572718295065, 2], [5.125884178959594, 2.641351118529748, 2], [3.279502854847178, 4.24734080086658, 2], [6.514027657268711, 3.404325725254695, 2], [4.351224264199084, 4.53888338569196, 2], [4.039344831104019, 4.0945230869911535, 2], [3.9454506894575965, 4.6423194842654425, 2], [3.855171947877163, 2.4554668009770473, 2], [2.599242104730057, 5.225970289392297, 2], [2.3434950546651807, 4.244392286279733, 2], [5.039811532048955, 3.4492772232871287, 2], [3.844485404908035, 3.9451335474294202, 2], [3.7015803972836383, 4.499658340032937, 2], [5.068736301825577, 3.2939393796125813, 2], [3.8323560023083174, 4.761164540431628, 2], [1.5176056286959994, 4.04983340708768, 2], [5.313558932903716, 3.544020995885696, 2], [3.290554824821844, 2.727150026124553, 2], [3.233575623793614, 5.8498288669829375, 2], [2.9370459398694457, 3.000952194606334, 2], [6.772367285795599, 5.113044854585574, 2], [3.7477787877177606, 2.8017489757362326, 2], [3.788361644651954, 1.7850341341901883, 2], [3.173492508736491, 4.46849687089367, 2], [2.893810908531677, 4.334054008192826, 2], [5.1477760205326115, 5.994932134483404, 2], [3.968102847356599, 3.289567250381807, 2], [2.6962276101021967, 2.6424896515071667, 2], [2.3733296894642253, 4.833063309080972, 2], [2.727469400399162, 5.397021656394974, 2], [3.82651114445061, 3.4416913139400913, 2], [4.374440533047921, 3.5668562165478086, 2]]\n"
     ]
    }
   ],
   "source": [
    "print(final_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Shuffle the rows in place so the classes are interleaved instead of grouped.\n",
    "# Re-running this cell reshuffles the already-shuffled list.\n",
    "np.random.shuffle(final_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[6.514027657268711, 3.404325725254695, 2], [-3.282863120585333, 1.80384111633762, 1], [-1.2956230400219464, 3.702464792606867, 0], [3.7708897158379004, 2.4944281040222593, 2], [-0.2304116223343647, 4.698521110136274, 0], [1.246837949790526, 3.7406124743448106, 0], [2.3733296894642253, 4.833063309080972, 2], [-0.34859519219417745, -0.7259618415461366, 1], [-2.591119027450392, -1.4567442051773023, 1], [-0.7376472230709924, 5.803097007130773, 0], [4.166019184020129, 4.80837886147176, 2], [-0.5552360764294617, 0.2721038194690457, 1], [3.9104920101067338, 2.676229432347111, 2], [-2.039116321391051, 2.949402971986678, 0], [-1.2206991377198881, 2.8968506738340065, 1], [3.82651114445061, 3.4416913139400913, 2], [4.604905521458402, 5.384261022221785, 2], [-2.9543271683694297, 1.5208661201266933, 1], [-2.3208471127769434, 3.42272269405798, 1], [-0.11573062205373608, 3.480755568518185, 0], [4.666000898797912, 5.267115372176956, 2], [1.755091160248367, 1.960071082702421, 0], [3.6109221872897788, 3.7523103442746994, 2], [-1.2290910552777952, -2.042402155845269, 1], [-0.6569063492545735, 4.262380962745065, 0], [4.0249287260083335, 2.460110034110473, 2], [3.8338886252094664, 4.347682114015021, 2], [-2.8744247000521312, -1.6356078277295545, 1], [4.665491231926977, 2.3405051084682, 2], [3.772737721339217, 4.720389099276566, 2], [-2.01753272907842, -0.07443334976717048, 1], [-1.24376997962011, 4.587481260319372, 0], [-3.396610989413893, -0.481025839655896, 1], [-2.182417968080539, 1.4386320553661633, 1], [4.76514198141676, 3.8767860619446477, 2], [2.507762812760115, 4.346572718295065, 2], [1.8884373234206377, 3.413593905948643, 0], [-2.6322240450555676, -0.287912718061372, 1], [1.9354884616216883, 4.029993226729789, 2], [-1.3751452747477573, 2.566603177341791, 1], [-2.9004937614795328, -0.3216117040748568, 1], [-1.8504815714735647, 2.898935202422041, 1], [-0.0797236171908722, 4.148615149643331, 0], [0.2759101397369275, 2.5995246773954275, 0], [-3.447420227297748, 
2.121301402055159, 1], [-1.7048565706654364, -0.46715997941076914, 1], [0.0009144524766546955, 3.4883492002658114, 0], [-1.8585559514218042, 2.9229067964579536, 0], [-2.7964090769620955, -1.238490430319433, 1], [2.1584434273177484, 3.382361255485536, 0], [-2.8444606016435223, 0.9313394387612425, 1], [3.8323560023083174, 4.761164540431628, 2], [3.3972252940428445, 5.617149302579655, 2], [-0.23028120313149847, -2.3693570257731844, 1], [3.968102847356599, 3.289567250381807, 2], [3.290554824821844, 2.727150026124553, 2], [-2.996935282004089, 0.2857148719990113, 1], [-1.5780948144365075, -0.8462349627383353, 1], [-1.6591661687572872, -0.05823522853658854, 1], [0.8750938312946315, 0.4034302678820234, 1], [0.3616868451134729, 4.461186774668471, 0], [4.1725064709712765, 3.4732330009523205, 2], [-2.972233246107584, 2.4394352282994163, 1], [5.068736301825577, 3.2939393796125813, 2], [-2.3256002163440597, -0.14409495477832063, 1], [-2.9033226684593885, -2.3568402619862, 1], [2.893810908531677, 4.334054008192826, 2], [-1.9732767789212167, -2.657157583190173, 1], [2.0803735950949687, 3.9420024711310737, 0], [-2.3863110549753124, 0.5397482546093665, 1], [-1.1728671749338666, 2.1952494264593727, 1], [-0.6563091088989254, 5.051006292393634, 0], [2.9782653914916417, 2.915201988049066, 2], [4.873477831427014, 3.911043062660721, 0], [-3.5725747516573803, 2.192335521818208, 1], [-0.6380925086824782, 0.029563983412872433, 1], [0.42299761560666266, 4.613648175125713, 0], [3.4591273612186004, 2.984414700652448, 2], [2.7107246512484533, 5.59263772985132, 2], [0.5469478265638494, 3.533230634808844, 0], [-2.6031984384631683, 2.0409392293072406, 1], [-1.2334497650352687, -0.42810907016445826, 1], [0.642013224293857, 4.629178721936455, 0], [1.1346679230545944, 4.2453793112648475, 0], [0.24295355936942525, 3.7690670506687525, 0], [3.957012458207807, 3.3811425030612883, 2], [-0.1393500485857051, 4.169106167012477, 0], [4.049289535652809, 3.4653285048635665, 2], [-1.110290442688222, 
-1.384521718778183, 1], [2.5527223778798636, 6.493607748081107, 2], [-3.4734488003805533, -2.645835042677507, 1], [5.313558932903716, 3.544020995885696, 2], [1.6999871219014488, 3.0271119638873043, 0], [-1.6702576184489226, 4.682037551900045, 0], [1.4526835758186951, 4.502273717152065, 0], [-3.3468167079620157, -3.083349863028997, 1], [-1.26937928167189, 3.9366603289412025, 0], [-1.1598230851730404, 3.1633666849826403, 0], [4.9242882792072, 3.2383677092167593, 2], [-1.8803827786671055, 0.11056830248492576, 1], [-0.17410428062495628, 4.3287683736566125, 0], [0.9603014627340807, 4.06899014584137, 0], [-3.1319494772141816, 1.0415140880284999, 1], [2.1320743995469598, 5.424193031378502, 0], [3.0968048034125757, 3.6806107244228894, 0], [4.270791491798424, 4.535809241272678, 2], [3.9454506894575965, 4.6423194842654425, 2], [0.3939423800045117, 3.6943452523501166, 0], [2.6676516030896726, 3.7563817279026774, 2], [-1.0807501147093004, 0.013209178405663127, 1], [4.350032113990249, 4.707696775095096, 2], [-3.2844665394677968, 1.1923808246168874, 1], [2.6533319152578354, 5.09147458811136, 2], [0.2632863000119277, 3.6418791057870057, 0], [2.776164727808217, 3.524784380252329, 2], [-2.0938587928573447, 4.607698040503408, 0], [5.148943859631081, 2.6311886560593436, 2], [0.5252337325015385, 4.488056876575435, 0], [-2.417695799060901, -0.3054211566559618, 1], [-0.48303865121086914, 2.9512899473322833, 0], [2.499890584424867, 2.44356436150654, 2], [1.3915330382176745, 3.609610779138298, 0], [4.039344831104019, 4.0945230869911535, 2], [2.9651806452726484, 3.6822567615252, 2], [2.3434950546651807, 4.244392286279733, 2], [0.6382892112881816, 3.588033841365842, 0], [4.606991707535205, 3.5275197733485495, 2], [4.146700239681935, 4.60966668621533, 2], [-2.6421538300903284, -1.9974657027053484, 1], [0.583212279631955, 2.580110125461964, 0], [-2.3456293811887794, 1.065115954289983, 1], [0.7055929486208558, 4.310387049225572, 0], [-1.8843533072735523, -0.8419113804803363, 1], 
[2.727469400399162, 5.397021656394974, 2], [-3.342799204071513, -1.136977066023454, 1], [-1.964489469838514, 0.7689148007239601, 1], [-2.619061481162773, -1.7698307321638587, 1], [0.4797272101616133, 5.296738129193235, 0], [3.844485404908035, 3.9451335474294202, 2], [-1.11562789887616, 4.271336233198918, 0], [-0.8925301247280577, -0.7168706709934206, 1], [-3.426929548555545, 4.731833142875785, 0], [3.233575623793614, 5.8498288669829375, 2], [4.157141466979694, 3.580783763138573, 2], [0.03566039932018928, 4.032797732705941, 0], [4.375699437400853, 2.5892950225439693, 2], [-1.6985413150142163, -0.44526191458764275, 1], [-2.5407860290347144, -0.6199577192431571, 1], [-1.6273887432371903, 5.505535694905153, 0], [3.358601241933496, 3.751019934059549, 2], [1.8851344601152946, 5.228459904906094, 2], [-1.5562786780749314, 0.19723627569547378, 1], [-1.6407126032585873, -1.732069813433546, 1], [3.497280862907807, 2.9506819966389277, 2], [3.279502854847178, 4.24734080086658, 2], [3.855171947877163, 2.4554668009770473, 2], [4.207963718422087, 5.04274126473347, 2], [-1.5533586164818436, -2.0185339460170533, 1], [-1.3030336514264167, -0.06048148958075131, 1], [-1.1067780728998662, 3.9078285367195775, 0], [-0.14490604294007797, 4.061320166617641, 0], [-1.3993130309070427, 4.026755341718367, 0], [4.374440533047921, 3.5668562165478086, 2], [-1.0490068844169445, 2.951040566785435, 1], [2.6331310807009083, 1.6993741527422488, 2], [-0.4957871575835118, 1.870276286706099, 1], [-2.3541297278981905, 4.617919747329587, 0], [-2.7716362876242373, 2.1011645644950923, 1], [-1.276376069262216, 3.6380032667868294, 0], [5.261642912047374, 4.043651240987026, 2], [-4.201020759276739, -1.5002038235241208, 1], [-1.0254931176690707, 3.6805312273656683, 0], [-0.7864902253465736, 4.199606423310847, 0], [0.3630595564901449, 3.8193121848730667, 0], [-2.203881737898663, 0.7043381033943146, 1], [-3.33873248655075, -0.08069373810368652, 1], [0.31344772269330834, -1.494449835471386, 1], [-2.0166469000233413, 
-0.12236212135416753, 1], [-3.604567881384009, 0.7790357409013815, 1], [-0.8946735160243385, 4.772101522859985, 0], [-2.6464073375410253, -0.7674622491902292, 1], [5.681966284429088, 3.0282332220845847, 2], [-1.2618910141262774, -0.5864298950798029, 1], [3.5593897975906197, 3.736527009372364, 0], [0.5121696913796528, 4.71794986611183, 0], [2.6962276101021967, 2.6424896515071667, 2], [-3.21317704887337, 4.267090754033438, 0], [2.0604561909115473, 4.936152528359206, 0], [5.039811532048955, 3.4492772232871287, 2], [3.7015803972836383, 4.499658340032937, 2], [3.542165228657921, 2.9810011102630414, 2], [-0.15337287638122987, -1.330478252621237, 1], [4.351224264199084, 4.53888338569196, 2], [4.290199151859662, 5.0955983081795315, 2], [0.5993964428518588, -0.04275487216011345, 1], [0.859832381065888, 4.062416658541541, 0], [-1.379386248249244, 0.77658975429236, 1], [-0.8820409969261933, -0.5454598952166543, 1], [2.726847687133478, 3.8029862212398635, 2], [-1.0991931205585597, -1.3547795730927457, 1], [4.779392227428427, 2.3226977512156877, 2], [4.402153267341136, 4.634952385707767, 2], [4.202740515544135, 5.187661389794847, 2], [2.890026511690566, 3.1437426794200065, 2], [-2.080134463932432, 4.6618379820484686, 0], [-0.5158730553503137, 4.3880572535000795, 0], [5.053880665840948, 4.726484408898449, 2], [-2.1749244254585474, 4.680191769512619, 0], [-2.7642373882075035, 3.715354172088658, 0], [5.069119303815527, 4.264150799069864, 2], [3.9267710702505663, 5.168810802001428, 2], [0.6264737578152862, 3.741794266275358, 0], [3.573051669651404, 4.65614350101775, 2], [-1.963552865256177, -1.1011084371028548, 1], [-3.1302980001607676, -0.3373076001846845, 1], [1.9159750400662772, 4.322823105875789, 0], [-1.5428857863742746, 1.6395250707132116, 1], [3.083429873694617, 3.8685653308881798, 2], [4.426970949858231, 3.9419052108087547, 2], [2.599242104730057, 5.225970289392297, 2], [3.600288307944474, 2.431281319237196, 2], [4.610543309440227, 3.683317237199498, 2], 
[-0.3274045137736968, 4.339044033478316, 0], [0.5626346758673177, 0.869462093590307, 1], [-0.24222779307396392, 3.392477267161001, 0], [-1.227059694348927, 4.312979802762684, 0], [5.30088056551321, 4.6871101039669005, 2], [-2.4558274876205752, 0.268352523555578, 1], [0.718229805430768, 4.310576411147394, 0], [-3.968262247197826, 4.100680107599374, 0], [-0.8936085001831094, 4.973320112542332, 0], [-3.209975590198013, 0.7092973441936908, 1], [-2.1699637476676146, -1.5435606932156916, 1], [-1.0448773708191637, 0.1889974603839339, 1], [1.5176056286959994, 4.04983340708768, 2], [-2.5517564445509646, -1.2321718169022546, 1], [-2.14918281232061, -0.05909182563279221, 1], [-0.20741064108600205, 3.868701095470327, 0], [2.8456608218091612, 3.274257814811507, 0], [-1.4167420293616817, -1.250583967129359, 1], [-0.7552911522794037, -1.366119654210927, 1], [2.289116415364021, 3.9870101671111535, 0], [-1.7376573738498051, 4.0222670151939015, 0], [-1.9341954949245357, -3.2471436673481913, 1], [5.2555452121963695, 4.876984638429942, 2], [-1.0621792893491195, -0.2481472844903763, 1], [-0.9363164618457444, 3.938951342152362, 0], [3.220805106984178, 3.6090911452898085, 2], [6.136225045515941, 4.845285611496297, 2], [5.1477760205326115, 5.994932134483404, 2], [-1.196559454138142, -1.1853566798593753, 1], [2.5263897404541424, 4.2863931066916505, 0], [-0.5578559773113465, 4.505703799310029, 0], [-2.7348217180204775, 2.522143850147877, 0], [-0.4956626365384753, -0.5175352324112511, 1], [2.5655620176720935, 3.3924447125717547, 2], [-2.946973346565818, 4.2007797192817495, 0], [2.253177787194711, 5.114644685392081, 0], [-1.229676620128568, 4.378887070121326, 0], [-1.9491041109786618, 0.7043809221523516, 1], [0.695391819473179, 3.215383688084266, 0], [-0.48399109180025357, 4.016821162942991, 0], [4.64304855128699, 2.7974921823469137, 2], [3.173492508736491, 4.46849687089367, 2], [2.9370459398694457, 3.000952194606334, 2], [3.788361644651954, 1.7850341341901883, 2], [5.2986441863874845, 
4.943179223643645, 2], [-3.0806583652946378, 4.313036328113317, 0], [-0.6467954979182303, 3.312358884754359, 0], [-4.408400952016439, 1.7771350320195927, 1], [6.096789524724665, 4.184960835913332, 2], [-3.586459858726821, 2.142805514281176, 1], [-0.4569849483877815, -2.0272409607900634, 1], [-2.2969780465391754, 3.606601458530838, 0], [6.772367285795599, 5.113044854585574, 2], [-0.7762488068320517, 1.6847176334540306, 1], [-1.816388404189842, 0.04502351606710421, 1], [3.2414058080462396, 4.675440322223491, 2], [3.4946270488405164, 3.6105050404937677, 2], [-1.6023973289952056, 5.098031584492405, 0], [-2.1057356340089886, 4.083206923044556, 0], [2.2344063935391776, 4.463233724669758, 0], [3.7477787877177606, 2.8017489757362326, 2], [-2.164677451957254, 2.2392098796939996, 1], [-1.4704502157593242, 3.9452476509906944, 0], [-1.4854834096634106, 1.3688840569498117, 1], [0.24511472118101468, 3.1712311787937524, 0], [-1.6820276081587804, 1.0208434104469255, 1], [4.959576830854408, 3.349687965447966, 2], [4.083362776456962, 4.750228926231906, 2], [-0.4329048612781419, 3.5574305717396104, 0], [-2.8908749024996423, -0.1546427282985843, 1], [-1.8125482079006536, 4.696281039511931, 0], [5.125884178959594, 2.641351118529748, 2], [5.103592669048636, 6.30408075741299, 2], [-2.253314351369113, 3.953951563079367, 0], [-1.552126868409797, 0.30242003040005017, 1], [2.383292899572844, 5.017411703106532, 0], [-0.8132328140379816, -1.288955651488098, 1], [2.290300231330779, 3.3096653562533285, 0]]\n"
     ]
    }
   ],
   "source": [
    "print(final_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6.514028</td>\n",
       "      <td>3.404326</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-3.282863</td>\n",
       "      <td>1.803841</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-1.295623</td>\n",
       "      <td>3.702465</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3.770890</td>\n",
       "      <td>2.494428</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.230412</td>\n",
       "      <td>4.698521</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         x1        x2  label\n",
       "0  6.514028  3.404326      2\n",
       "1 -3.282863  1.803841      1\n",
       "2 -1.295623  3.702465      0\n",
       "3  3.770890  2.494428      2\n",
       "4 -0.230412  4.698521      0"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_data = pd.DataFrame(final_data, columns = ['x1', 'x2', 'label'])\n",
    "final_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X = final_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.43120952  0.3964526   0.17233788]\n"
     ]
    }
   ],
   "source": [
    "# N = 300 d = 2\n",
    "\n",
    "pis = np.random.random(3)\n",
    "pis /= pis.sum()\n",
    "print pis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.14637588,  4.04830935],\n",
       "       [-1.93894147, -0.01201856],\n",
       "       [ 3.92480039,  3.9630975 ]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mus = np.array([X[X['label'] == num].mean(axis = 0)[:-1] for num in range(3)])\n",
    "mus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[ 1.,  0.],\n",
       "        [ 0.,  1.]],\n",
       "\n",
       "       [[ 1.,  0.],\n",
       "        [ 0.,  1.]],\n",
       "\n",
       "       [[ 1.,  0.],\n",
       "        [ 0.,  1.]]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sigmas = np.array([np.eye(2)] * 3)\n",
    "sigmas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X_hat = X.iloc[:,:-1].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[  6.51402766e+00   3.40432573e+00]\n",
      " [ -3.28286312e+00   1.80384112e+00]\n",
      " [ -1.29562304e+00   3.70246479e+00]\n",
      " [  3.77088972e+00   2.49442810e+00]\n",
      " [ -2.30411622e-01   4.69852111e+00]\n",
      " [  1.24683795e+00   3.74061247e+00]\n",
      " [  2.37332969e+00   4.83306331e+00]\n",
      " [ -3.48595192e-01  -7.25961842e-01]\n",
      " [ -2.59111903e+00  -1.45674421e+00]\n",
      " [ -7.37647223e-01   5.80309701e+00]\n",
      " [  4.16601918e+00   4.80837886e+00]\n",
      " [ -5.55236076e-01   2.72103819e-01]\n",
      " [  3.91049201e+00   2.67622943e+00]\n",
      " [ -2.03911632e+00   2.94940297e+00]\n",
      " [ -1.22069914e+00   2.89685067e+00]\n",
      " [  3.82651114e+00   3.44169131e+00]\n",
      " [  4.60490552e+00   5.38426102e+00]\n",
      " [ -2.95432717e+00   1.52086612e+00]\n",
      " [ -2.32084711e+00   3.42272269e+00]\n",
      " [ -1.15730622e-01   3.48075557e+00]\n",
      " [  4.66600090e+00   5.26711537e+00]\n",
      " [  1.75509116e+00   1.96007108e+00]\n",
      " [  3.61092219e+00   3.75231034e+00]\n",
      " [ -1.22909106e+00  -2.04240216e+00]\n",
      " [ -6.56906349e-01   4.26238096e+00]\n",
      " [  4.02492873e+00   2.46011003e+00]\n",
      " [  3.83388863e+00   4.34768211e+00]\n",
      " [ -2.87442470e+00  -1.63560783e+00]\n",
      " [  4.66549123e+00   2.34050511e+00]\n",
      " [  3.77273772e+00   4.72038910e+00]\n",
      " [ -2.01753273e+00  -7.44333498e-02]\n",
      " [ -1.24376998e+00   4.58748126e+00]\n",
      " [ -3.39661099e+00  -4.81025840e-01]\n",
      " [ -2.18241797e+00   1.43863206e+00]\n",
      " [  4.76514198e+00   3.87678606e+00]\n",
      " [  2.50776281e+00   4.34657272e+00]\n",
      " [  1.88843732e+00   3.41359391e+00]\n",
      " [ -2.63222405e+00  -2.87912718e-01]\n",
      " [  1.93548846e+00   4.02999323e+00]\n",
      " [ -1.37514527e+00   2.56660318e+00]\n",
      " [ -2.90049376e+00  -3.21611704e-01]\n",
      " [ -1.85048157e+00   2.89893520e+00]\n",
      " [ -7.97236172e-02   4.14861515e+00]\n",
      " [  2.75910140e-01   2.59952468e+00]\n",
      " [ -3.44742023e+00   2.12130140e+00]\n",
      " [ -1.70485657e+00  -4.67159979e-01]\n",
      " [  9.14452477e-04   3.48834920e+00]\n",
      " [ -1.85855595e+00   2.92290680e+00]\n",
      " [ -2.79640908e+00  -1.23849043e+00]\n",
      " [  2.15844343e+00   3.38236126e+00]\n",
      " [ -2.84446060e+00   9.31339439e-01]\n",
      " [  3.83235600e+00   4.76116454e+00]\n",
      " [  3.39722529e+00   5.61714930e+00]\n",
      " [ -2.30281203e-01  -2.36935703e+00]\n",
      " [  3.96810285e+00   3.28956725e+00]\n",
      " [  3.29055482e+00   2.72715003e+00]\n",
      " [ -2.99693528e+00   2.85714872e-01]\n",
      " [ -1.57809481e+00  -8.46234963e-01]\n",
      " [ -1.65916617e+00  -5.82352285e-02]\n",
      " [  8.75093831e-01   4.03430268e-01]\n",
      " [  3.61686845e-01   4.46118677e+00]\n",
      " [  4.17250647e+00   3.47323300e+00]\n",
      " [ -2.97223325e+00   2.43943523e+00]\n",
      " [  5.06873630e+00   3.29393938e+00]\n",
      " [ -2.32560022e+00  -1.44094955e-01]\n",
      " [ -2.90332267e+00  -2.35684026e+00]\n",
      " [  2.89381091e+00   4.33405401e+00]\n",
      " [ -1.97327678e+00  -2.65715758e+00]\n",
      " [  2.08037360e+00   3.94200247e+00]\n",
      " [ -2.38631105e+00   5.39748255e-01]\n",
      " [ -1.17286717e+00   2.19524943e+00]\n",
      " [ -6.56309109e-01   5.05100629e+00]\n",
      " [  2.97826539e+00   2.91520199e+00]\n",
      " [  4.87347783e+00   3.91104306e+00]\n",
      " [ -3.57257475e+00   2.19233552e+00]\n",
      " [ -6.38092509e-01   2.95639834e-02]\n",
      " [  4.22997616e-01   4.61364818e+00]\n",
      " [  3.45912736e+00   2.98441470e+00]\n",
      " [  2.71072465e+00   5.59263773e+00]\n",
      " [  5.46947827e-01   3.53323063e+00]\n",
      " [ -2.60319844e+00   2.04093923e+00]\n",
      " [ -1.23344977e+00  -4.28109070e-01]\n",
      " [  6.42013224e-01   4.62917872e+00]\n",
      " [  1.13466792e+00   4.24537931e+00]\n",
      " [  2.42953559e-01   3.76906705e+00]\n",
      " [  3.95701246e+00   3.38114250e+00]\n",
      " [ -1.39350049e-01   4.16910617e+00]\n",
      " [  4.04928954e+00   3.46532850e+00]\n",
      " [ -1.11029044e+00  -1.38452172e+00]\n",
      " [  2.55272238e+00   6.49360775e+00]\n",
      " [ -3.47344880e+00  -2.64583504e+00]\n",
      " [  5.31355893e+00   3.54402100e+00]\n",
      " [  1.69998712e+00   3.02711196e+00]\n",
      " [ -1.67025762e+00   4.68203755e+00]\n",
      " [  1.45268358e+00   4.50227372e+00]\n",
      " [ -3.34681671e+00  -3.08334986e+00]\n",
      " [ -1.26937928e+00   3.93666033e+00]\n",
      " [ -1.15982309e+00   3.16336668e+00]\n",
      " [  4.92428828e+00   3.23836771e+00]\n",
      " [ -1.88038278e+00   1.10568302e-01]\n",
      " [ -1.74104281e-01   4.32876837e+00]\n",
      " [  9.60301463e-01   4.06899015e+00]\n",
      " [ -3.13194948e+00   1.04151409e+00]\n",
      " [  2.13207440e+00   5.42419303e+00]\n",
      " [  3.09680480e+00   3.68061072e+00]\n",
      " [  4.27079149e+00   4.53580924e+00]\n",
      " [  3.94545069e+00   4.64231948e+00]\n",
      " [  3.93942380e-01   3.69434525e+00]\n",
      " [  2.66765160e+00   3.75638173e+00]\n",
      " [ -1.08075011e+00   1.32091784e-02]\n",
      " [  4.35003211e+00   4.70769678e+00]\n",
      " [ -3.28446654e+00   1.19238082e+00]\n",
      " [  2.65333192e+00   5.09147459e+00]\n",
      " [  2.63286300e-01   3.64187911e+00]\n",
      " [  2.77616473e+00   3.52478438e+00]\n",
      " [ -2.09385879e+00   4.60769804e+00]\n",
      " [  5.14894386e+00   2.63118866e+00]\n",
      " [  5.25233733e-01   4.48805688e+00]\n",
      " [ -2.41769580e+00  -3.05421157e-01]\n",
      " [ -4.83038651e-01   2.95128995e+00]\n",
      " [  2.49989058e+00   2.44356436e+00]\n",
      " [  1.39153304e+00   3.60961078e+00]\n",
      " [  4.03934483e+00   4.09452309e+00]\n",
      " [  2.96518065e+00   3.68225676e+00]\n",
      " [  2.34349505e+00   4.24439229e+00]\n",
      " [  6.38289211e-01   3.58803384e+00]\n",
      " [  4.60699171e+00   3.52751977e+00]\n",
      " [  4.14670024e+00   4.60966669e+00]\n",
      " [ -2.64215383e+00  -1.99746570e+00]\n",
      " [  5.83212280e-01   2.58011013e+00]\n",
      " [ -2.34562938e+00   1.06511595e+00]\n",
      " [  7.05592949e-01   4.31038705e+00]\n",
      " [ -1.88435331e+00  -8.41911380e-01]\n",
      " [  2.72746940e+00   5.39702166e+00]\n",
      " [ -3.34279920e+00  -1.13697707e+00]\n",
      " [ -1.96448947e+00   7.68914801e-01]\n",
      " [ -2.61906148e+00  -1.76983073e+00]\n",
      " [  4.79727210e-01   5.29673813e+00]\n",
      " [  3.84448540e+00   3.94513355e+00]\n",
      " [ -1.11562790e+00   4.27133623e+00]\n",
      " [ -8.92530125e-01  -7.16870671e-01]\n",
      " [ -3.42692955e+00   4.73183314e+00]\n",
      " [  3.23357562e+00   5.84982887e+00]\n",
      " [  4.15714147e+00   3.58078376e+00]\n",
      " [  3.56603993e-02   4.03279773e+00]\n",
      " [  4.37569944e+00   2.58929502e+00]\n",
      " [ -1.69854132e+00  -4.45261915e-01]\n",
      " [ -2.54078603e+00  -6.19957719e-01]\n",
      " [ -1.62738874e+00   5.50553569e+00]\n",
      " [  3.35860124e+00   3.75101993e+00]\n",
      " [  1.88513446e+00   5.22845990e+00]\n",
      " [ -1.55627868e+00   1.97236276e-01]\n",
      " [ -1.64071260e+00  -1.73206981e+00]\n",
      " [  3.49728086e+00   2.95068200e+00]\n",
      " [  3.27950285e+00   4.24734080e+00]\n",
      " [  3.85517195e+00   2.45546680e+00]\n",
      " [  4.20796372e+00   5.04274126e+00]\n",
      " [ -1.55335862e+00  -2.01853395e+00]\n",
      " [ -1.30303365e+00  -6.04814896e-02]\n",
      " [ -1.10677807e+00   3.90782854e+00]\n",
      " [ -1.44906043e-01   4.06132017e+00]\n",
      " [ -1.39931303e+00   4.02675534e+00]\n",
      " [  4.37444053e+00   3.56685622e+00]\n",
      " [ -1.04900688e+00   2.95104057e+00]\n",
      " [  2.63313108e+00   1.69937415e+00]\n",
      " [ -4.95787158e-01   1.87027629e+00]\n",
      " [ -2.35412973e+00   4.61791975e+00]\n",
      " [ -2.77163629e+00   2.10116456e+00]\n",
      " [ -1.27637607e+00   3.63800327e+00]\n",
      " [  5.26164291e+00   4.04365124e+00]\n",
      " [ -4.20102076e+00  -1.50020382e+00]\n",
      " [ -1.02549312e+00   3.68053123e+00]\n",
      " [ -7.86490225e-01   4.19960642e+00]\n",
      " [  3.63059556e-01   3.81931218e+00]\n",
      " [ -2.20388174e+00   7.04338103e-01]\n",
      " [ -3.33873249e+00  -8.06937381e-02]\n",
      " [  3.13447723e-01  -1.49444984e+00]\n",
      " [ -2.01664690e+00  -1.22362121e-01]\n",
      " [ -3.60456788e+00   7.79035741e-01]\n",
      " [ -8.94673516e-01   4.77210152e+00]\n",
      " [ -2.64640734e+00  -7.67462249e-01]\n",
      " [  5.68196628e+00   3.02823322e+00]\n",
      " [ -1.26189101e+00  -5.86429895e-01]\n",
      " [  3.55938980e+00   3.73652701e+00]\n",
      " [  5.12169691e-01   4.71794987e+00]\n",
      " [  2.69622761e+00   2.64248965e+00]\n",
      " [ -3.21317705e+00   4.26709075e+00]\n",
      " [  2.06045619e+00   4.93615253e+00]\n",
      " [  5.03981153e+00   3.44927722e+00]\n",
      " [  3.70158040e+00   4.49965834e+00]\n",
      " [  3.54216523e+00   2.98100111e+00]\n",
      " [ -1.53372876e-01  -1.33047825e+00]\n",
      " [  4.35122426e+00   4.53888339e+00]\n",
      " [  4.29019915e+00   5.09559831e+00]\n",
      " [  5.99396443e-01  -4.27548722e-02]\n",
      " [  8.59832381e-01   4.06241666e+00]\n",
      " [ -1.37938625e+00   7.76589754e-01]\n",
      " [ -8.82040997e-01  -5.45459895e-01]\n",
      " [  2.72684769e+00   3.80298622e+00]\n",
      " [ -1.09919312e+00  -1.35477957e+00]\n",
      " [  4.77939223e+00   2.32269775e+00]\n",
      " [  4.40215327e+00   4.63495239e+00]\n",
      " [  4.20274052e+00   5.18766139e+00]\n",
      " [  2.89002651e+00   3.14374268e+00]\n",
      " [ -2.08013446e+00   4.66183798e+00]\n",
      " [ -5.15873055e-01   4.38805725e+00]\n",
      " [  5.05388067e+00   4.72648441e+00]\n",
      " [ -2.17492443e+00   4.68019177e+00]\n",
      " [ -2.76423739e+00   3.71535417e+00]\n",
      " [  5.06911930e+00   4.26415080e+00]\n",
      " [  3.92677107e+00   5.16881080e+00]\n",
      " [  6.26473758e-01   3.74179427e+00]\n",
      " [  3.57305167e+00   4.65614350e+00]\n",
      " [ -1.96355287e+00  -1.10110844e+00]\n",
      " [ -3.13029800e+00  -3.37307600e-01]\n",
      " [  1.91597504e+00   4.32282311e+00]\n",
      " [ -1.54288579e+00   1.63952507e+00]\n",
      " [  3.08342987e+00   3.86856533e+00]\n",
      " [  4.42697095e+00   3.94190521e+00]\n",
      " [  2.59924210e+00   5.22597029e+00]\n",
      " [  3.60028831e+00   2.43128132e+00]\n",
      " [  4.61054331e+00   3.68331724e+00]\n",
      " [ -3.27404514e-01   4.33904403e+00]\n",
      " [  5.62634676e-01   8.69462094e-01]\n",
      " [ -2.42227793e-01   3.39247727e+00]\n",
      " [ -1.22705969e+00   4.31297980e+00]\n",
      " [  5.30088057e+00   4.68711010e+00]\n",
      " [ -2.45582749e+00   2.68352524e-01]\n",
      " [  7.18229805e-01   4.31057641e+00]\n",
      " [ -3.96826225e+00   4.10068011e+00]\n",
      " [ -8.93608500e-01   4.97332011e+00]\n",
      " [ -3.20997559e+00   7.09297344e-01]\n",
      " [ -2.16996375e+00  -1.54356069e+00]\n",
      " [ -1.04487737e+00   1.88997460e-01]\n",
      " [  1.51760563e+00   4.04983341e+00]\n",
      " [ -2.55175644e+00  -1.23217182e+00]\n",
      " [ -2.14918281e+00  -5.90918256e-02]\n",
      " [ -2.07410641e-01   3.86870110e+00]\n",
      " [  2.84566082e+00   3.27425781e+00]\n",
      " [ -1.41674203e+00  -1.25058397e+00]\n",
      " [ -7.55291152e-01  -1.36611965e+00]\n",
      " [  2.28911642e+00   3.98701017e+00]\n",
      " [ -1.73765737e+00   4.02226702e+00]\n",
      " [ -1.93419549e+00  -3.24714367e+00]\n",
      " [  5.25554521e+00   4.87698464e+00]\n",
      " [ -1.06217929e+00  -2.48147284e-01]\n",
      " [ -9.36316462e-01   3.93895134e+00]\n",
      " [  3.22080511e+00   3.60909115e+00]\n",
      " [  6.13622505e+00   4.84528561e+00]\n",
      " [  5.14777602e+00   5.99493213e+00]\n",
      " [ -1.19655945e+00  -1.18535668e+00]\n",
      " [  2.52638974e+00   4.28639311e+00]\n",
      " [ -5.57855977e-01   4.50570380e+00]\n",
      " [ -2.73482172e+00   2.52214385e+00]\n",
      " [ -4.95662637e-01  -5.17535232e-01]\n",
      " [  2.56556202e+00   3.39244471e+00]\n",
      " [ -2.94697335e+00   4.20077972e+00]\n",
      " [  2.25317779e+00   5.11464469e+00]\n",
      " [ -1.22967662e+00   4.37888707e+00]\n",
      " [ -1.94910411e+00   7.04380922e-01]\n",
      " [  6.95391819e-01   3.21538369e+00]\n",
      " [ -4.83991092e-01   4.01682116e+00]\n",
      " [  4.64304855e+00   2.79749218e+00]\n",
      " [  3.17349251e+00   4.46849687e+00]\n",
      " [  2.93704594e+00   3.00095219e+00]\n",
      " [  3.78836164e+00   1.78503413e+00]\n",
      " [  5.29864419e+00   4.94317922e+00]\n",
      " [ -3.08065837e+00   4.31303633e+00]\n",
      " [ -6.46795498e-01   3.31235888e+00]\n",
      " [ -4.40840095e+00   1.77713503e+00]\n",
      " [  6.09678952e+00   4.18496084e+00]\n",
      " [ -3.58645986e+00   2.14280551e+00]\n",
      " [ -4.56984948e-01  -2.02724096e+00]\n",
      " [ -2.29697805e+00   3.60660146e+00]\n",
      " [  6.77236729e+00   5.11304485e+00]\n",
      " [ -7.76248807e-01   1.68471763e+00]\n",
      " [ -1.81638840e+00   4.50235161e-02]\n",
      " [  3.24140581e+00   4.67544032e+00]\n",
      " [  3.49462705e+00   3.61050504e+00]\n",
      " [ -1.60239733e+00   5.09803158e+00]\n",
      " [ -2.10573563e+00   4.08320692e+00]\n",
      " [  2.23440639e+00   4.46323372e+00]\n",
      " [  3.74777879e+00   2.80174898e+00]\n",
      " [ -2.16467745e+00   2.23920988e+00]\n",
      " [ -1.47045022e+00   3.94524765e+00]\n",
      " [ -1.48548341e+00   1.36888406e+00]\n",
      " [  2.45114721e-01   3.17123118e+00]\n",
      " [ -1.68202761e+00   1.02084341e+00]\n",
      " [  4.95957683e+00   3.34968797e+00]\n",
      " [  4.08336278e+00   4.75022893e+00]\n",
      " [ -4.32904861e-01   3.55743057e+00]\n",
      " [ -2.89087490e+00  -1.54642728e-01]\n",
      " [ -1.81254821e+00   4.69628104e+00]\n",
      " [  5.12588418e+00   2.64135112e+00]\n",
      " [  5.10359267e+00   6.30408076e+00]\n",
      " [ -2.25331435e+00   3.95395156e+00]\n",
      " [ -1.55212687e+00   3.02420030e-01]\n",
      " [  2.38329290e+00   5.01741170e+00]\n",
      " [ -8.13232814e-01  -1.28895565e+00]\n",
      " [  2.29030023e+00   3.30966536e+00]]\n"
     ]
    }
   ],
   "source": [
    "print X_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0      2\n",
      "1      1\n",
      "2      0\n",
      "3      2\n",
      "4      0\n",
      "5      0\n",
      "6      2\n",
      "7      1\n",
      "8      1\n",
      "9      0\n",
      "10     2\n",
      "11     1\n",
      "12     2\n",
      "13     0\n",
      "14     1\n",
      "15     2\n",
      "16     2\n",
      "17     1\n",
      "18     1\n",
      "19     0\n",
      "20     2\n",
      "21     0\n",
      "22     2\n",
      "23     1\n",
      "24     0\n",
      "25     2\n",
      "26     2\n",
      "27     1\n",
      "28     2\n",
      "29     2\n",
      "      ..\n",
      "270    2\n",
      "271    1\n",
      "272    1\n",
      "273    0\n",
      "274    2\n",
      "275    1\n",
      "276    1\n",
      "277    2\n",
      "278    2\n",
      "279    0\n",
      "280    0\n",
      "281    0\n",
      "282    2\n",
      "283    1\n",
      "284    0\n",
      "285    1\n",
      "286    0\n",
      "287    1\n",
      "288    2\n",
      "289    2\n",
      "290    0\n",
      "291    1\n",
      "292    0\n",
      "293    2\n",
      "294    2\n",
      "295    0\n",
      "296    1\n",
      "297    0\n",
      "298    1\n",
      "299    0\n",
      "Name: label, Length: 300, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "print X.iloc[:,-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# expectation step\n",
    "from scipy.stats import multivariate_normal\n",
    "\n",
    "gammas  = np.array([value[0]*multivariate_normal(mean = value[1], cov = value[2]).pdf(X_hat) for value in zip(pis, mus, sigmas)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(gammas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "300"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(multivariate_normal(mean = mus[0], cov = sigmas[0]).pdf(X_hat))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "300"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(gammas.sum(axis = 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "gammas = gammas / gammas.sum(axis = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1, 2],\n",
       "       [1, 2],\n",
       "       [1, 1]])"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t1 = np.array([[1,2],[10,40],[5,6]])\n",
    "\n",
    "t2 = np.array([1, 2, 5]).reshape(-1, 1)\n",
    "\n",
    "t1/t2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
